import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
class LSTM(nn.Module):
    """Stacked LSTM with a linear head applied to the last time step.

    Args:
        input_dim: Feature size of each time step (the last input axis).
        hidden_dim: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        output_dim: Size of the linear head's output.

    The input must be 4-D with a singleton dim 1, i.e. (B, 1, S, input_dim);
    callers in this file appear to pass S == 1 (TODO confirm). The output has
    shape (B, 1, 1, output_dim).
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # (B, 1, S, F) -> (B, S, F); batch_first=True reads it as (batch, seq, feature).
        x = x.squeeze(1)
        batch_size = x.size(0)
        # Zero initial states on the input's device and dtype. Pinning them to
        # CPU float32 broke the module after .to("cuda") or .double(). As before,
        # they are plain zeros that carry no gradient.
        h0 = x.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        c0 = x.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        out, _ = self.lstm(x, (h0, c0))
        # Keep the last time step: (B, S, hidden_dim) -> (B, hidden_dim).
        # When S == 1 this is the same as squeezing dim 1.
        out = self.fc(out[:, -1, :])
        # (B, output_dim) -> (B, 1, 1, output_dim). A triple unsqueeze followed
        # by squeeze(-1) only gives this shape when output_dim == 1.
        return out.unsqueeze(1).unsqueeze(2)


class align(nn.Module):
    """Bring a 4-D (N, C, H, W) tensor's channel count to ``c_out``.

    This is used to make residual branches addable:
    - fewer output channels than input: a learned 1x1 convolution;
    - more output channels: zero-valued channels appended after the existing ones;
    - equal counts: the input is returned unchanged (same object).
    """

    def __init__(self, c_in, c_out):
        super(align, self).__init__()
        self.c_in, self.c_out = c_in, c_out
        if c_out < c_in:
            self.conv1x1 = nn.Conv2d(c_in, c_out, 1)

    def forward(self, x):
        gap = self.c_out - self.c_in
        if gap < 0:
            return self.conv1x1(x)
        if gap > 0:
            # F.pad takes (before, after) pairs starting from the last dim. The
            # third pair addresses dim 1, so ``gap`` zero channels go at the end.
            return F.pad(x, [0, 0, 0, 0, 0, gap, 0, 0])
        return x


class temporal_conv_layer(nn.Module):
    """Convolution along dim 2 with a (kt, 1) kernel plus a residual connection.

    The convolution uses no padding, so dim 2 shrinks by ``kt - 1``. The residual
    is ``align(x)`` with its first ``kt - 1`` entries on dim 2 dropped so that
    the shapes match.

    Activations:
        "GLU":     conv produces 2*c_out channels; (first half + residual) * sigmoid(second half)
        "sigmoid": sigmoid(conv + residual)
        anything else: relu(conv + residual)
    """

    def __init__(self, kt, c_in, c_out, act="relu"):
        super(temporal_conv_layer, self).__init__()
        self.kt = kt
        self.act = act
        self.c_out = c_out
        self.align = align(c_in, c_out)
        conv_channels = c_out * 2 if act == "GLU" else c_out
        self.conv = nn.Conv2d(c_in, conv_channels, (kt, 1), 1)

    def forward(self, x):
        residual = self.align(x)[:, :, self.kt - 1:, :]
        conv = self.conv(x)
        if self.act == "GLU":
            value, gate = conv[:, :self.c_out], conv[:, self.c_out:]
            return (value + residual) * torch.sigmoid(gate)
        pre_activation = conv + residual
        if self.act == "sigmoid":
            return torch.sigmoid(pre_activation)
        return torch.relu(pre_activation)


class spatio_conv_layer(nn.Module):
    """Graph convolution with a residual connection and ReLU.

    Each of the ``ks`` operators in ``Lk`` is applied along the input's last
    axis (presumably the node axis). The results are then mixed across channels
    and operators by ``theta``.

    Args:
        ks: Number of graph operators; must equal ``Lk.size(0)``.
        c: Channel count of both input and output (required by the residual add).
        Lk: Tensor of shape (ks, n, n) holding the graph operators.

    Input and output shape: (B, c, T, n).
    """

    def __init__(self, ks, c, Lk):
        super(spatio_conv_layer, self).__init__()
        # Lk is a non-persistent buffer. It now follows .to(device) / .double()
        # with the module; a plain attribute stayed on its original device and
        # dtype. It is kept out of state_dict so existing checkpoints still load
        # with strict=True.
        self.register_buffer("Lk", Lk, persistent=False)
        self.theta = nn.Parameter(torch.FloatTensor(c, c, ks))
        self.b = nn.Parameter(torch.FloatTensor(1, c, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        # Same scheme nn.Conv2d uses for its default init: Kaiming-uniform weights
        # and a bias drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        init.kaiming_uniform_(self.theta, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.theta)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(self.b, -bound, bound)

    def forward(self, x):
        # Apply each operator k along the last axis: (B, c, T, n) -> (B, c, T, ks, n).
        x_c = torch.einsum("knm,bitm->bitkn", self.Lk, x)
        # Mix input channels i and operators k into output channels o: -> (B, c, T, n).
        x_gc = torch.einsum("iok,bitkn->botn", self.theta, x_c) + self.b
        return torch.relu(x_gc + x)


class st_conv_block(nn.Module):
    """One spatio-temporal block with four stages.

    The stages are: temporal GLU conv, graph conv, temporal ReLU conv, then
    BatchNorm and dropout.

    Args:
        ks: Number of graph operators for the spatial layer.
        kt: Temporal kernel size; each temporal conv trims ``kt - 1`` steps.
        n: Unused; kept for signature compatibility. It only sized a LayerNorm
            that has since been replaced by BatchNorm2d.
        c: Channel sizes, read as c[0] (input), c[1] (middle), c[2] (output).
        p: Dropout probability.
        Lk: Graph operators passed to the spatial layer.
    """

    def __init__(self, ks, kt, n, c, p, Lk):
        super(st_conv_block, self).__init__()
        c_in, c_mid, c_out = c[0], c[1], c[2]
        self.tconv1 = temporal_conv_layer(kt, c_in, c_mid, "GLU")
        self.sconv = spatio_conv_layer(ks, c_mid, Lk)
        self.tconv2 = temporal_conv_layer(kt, c_mid, c_out)
        # The attribute is still called ``ln`` (from the earlier LayerNorm) so
        # that state_dict keys stay stable.
        self.ln = nn.BatchNorm2d(c_out)
        self.dropout = nn.Dropout(p)

    def forward(self, x):
        h = self.tconv1(x)
        h = self.sconv(h)
        h = self.tconv2(h)
        normed = self.ln(h)
        return self.dropout(normed)


class fully_conv_layer(nn.Module):
    """Project ``c`` input channels down to a single channel with a 1x1 convolution."""

    def __init__(self, c):
        super(fully_conv_layer, self).__init__()
        self.conv = nn.Conv2d(in_channels=c, out_channels=1, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class out_set_dim(nn.Module):
    """Global-average-pool a feature map and project its channels to one value.

    The input is presumably (B, C, H, W) (TODO confirm). The output has shape
    (B, 1, 1, 1).

    ``fc`` is a lazy linear layer: ``in_features`` (= C) is inferred on the
    first forward pass and fixed after that. Run one forward pass before
    building an optimizer so the weight is materialised.
    """

    def __init__(self):
        super(out_set_dim, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The previous design rebuilt a fresh nn.Linear inside forward() on every
        # call. That gave random, untrainable weights that ignored the module's
        # device. LazyLinear creates the layer once, registers it as a
        # submodule, and follows .to(device).
        self.fc = nn.LazyLinear(1)

    def forward(self, x):
        x = self.avgpool(x)            # (B, C, H, W) -> (B, C, 1, 1)
        x = x.view(x.size(0), -1)      # -> (B, C)
        x = self.fc(x)                 # -> (B, 1)
        return x.view(-1, 1, 1, 1)
class output_layer(nn.Module):
    """Output head with four stages.

    The stages are: temporal GLU conv over ``T`` steps, BatchNorm, a 1-step
    sigmoid temporal conv, and a 1x1 projection to one channel.

    Args:
        c: Channel count flowing through the head.
        T: Temporal kernel of the first conv; it consumes T steps of dim 2.
        n: Unused; kept for signature compatibility (it only sized a former LayerNorm).
    """

    def __init__(self, c, T, n):
        super(output_layer, self).__init__()
        self.tconv1 = temporal_conv_layer(T, c, c, "GLU")
        # Still named ``ln`` (from the earlier LayerNorm) to keep state_dict keys stable.
        self.ln = nn.BatchNorm2d(c)
        self.tconv2 = temporal_conv_layer(1, c, c, "sigmoid")
        self.fc = fully_conv_layer(c)

    def forward(self, x):
        for stage in (self.tconv1, self.ln, self.tconv2, self.fc):
            x = stage(x)
        return x


class Transformer(nn.Module):
    """Single-layer, single-head Transformer encoder followed by a linear head.

    Args:
        input_dim: Model width (``d_model``); the size of the input's last axis.
        hidden_dim: Feed-forward width inside the encoder layer.
        output_dim: Size of the linear head's output.

    Expects a 4-D input whose dim 1 has size 1, e.g. (B, 1, 1, input_dim).
    That input is encoded as a length-1 token sequence, and the output has
    shape (B, 1, 1, output_dim).
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Transformer, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=1,
            dim_feedforward=hidden_dim,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=1)
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        tokens = x.squeeze(1)              # (B, 1, L, D) -> (B, L, D)
        encoded = self.transformer(tokens)
        pooled = encoded.squeeze(1)        # (B, 1, D) -> (B, D) when L == 1
        projected = self.fc(pooled)        # -> (B, output_dim)
        return projected.unsqueeze(1).unsqueeze(2)

class GRUNet(nn.Module):
    """GRU followed by a linear head on the final step.

    Each sample is flattened and regrouped as (input_size, L), then transposed
    to (L, input_size). The GRU therefore runs over L steps with
    ``input_size`` features per step. For an input presumably shaped
    (B, 1, T, n) with ``input_size == T`` (TODO confirm), this walks over the
    n columns with T values each.

    Returns a tensor of shape (B, 1, 1, output_size).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(GRUNet, self).__init__()
        self.gru = nn.GRU(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size = x.size(0)
        # The group size must be the GRU's feature width. It was hard-coded to 3,
        # which only worked when input_size == 3. reshape (not view) also accepts
        # non-contiguous inputs.
        steps = x.reshape(batch_size, self.gru.input_size, -1).transpose(1, 2)
        out, _ = self.gru(steps)
        out = self.fc(out[:, -1, :])  # output at the last step: (B, output_size)
        return out.unsqueeze(1).unsqueeze(2)

class STGCN(nn.Module):
    """Two ST-conv blocks plus an output head, fused with a GRU branch, then a Transformer.

    ``forward`` returns ``(transformer(fused), fused)``, where ``fused`` is the
    sum of the output head and the GRU branch.
    """

    def __init__(self, ks, kt, bs, T, n, Lk, p, shapes):
        super(STGCN, self).__init__()
        self.st_conv1 = st_conv_block(ks, kt, n, bs[0], p, Lk)
        self.st_conv2 = st_conv_block(ks, kt, n, bs[1], p, Lk)
        # Two blocks with two temporal convs each; every conv trims kt - 1 steps.
        remaining_steps = T - 4 * (kt - 1)
        self.output = output_layer(bs[1][2], remaining_steps, n)

        feature_dim = shapes
        hidden = bs[0][2]
        # NOTE(review): bs[0][0] (the first block's input channel count) doubles
        # as the LSTM layer count; confirm this is intended.
        # NOTE(review): self.lstm is never used in forward(), and LSTM_STGCN has
        # its own LSTM. It is kept so that state_dict keys (and existing
        # checkpoints) remain compatible.
        self.lstm = LSTM(feature_dim, hidden, bs[0][0], 1)
        self.transformer = Transformer(feature_dim, hidden, 1)
        self.gru = GRUNet(T, hidden, shapes)

    def forward(self, x):
        spatial = self.output(self.st_conv2(self.st_conv1(x)))
        fused = spatial + self.gru(x)
        return self.transformer(fused), fused

class LSTM_STGCN(nn.Module):
    """Weighted sum of the STGCN/Transformer output and an LSTM over the fused features.

    Args:
        ks, kt, bs, T, n, Lk, p, shapes: Forwarded to ``STGCN``; ``shapes``
            and ``bs[0]`` also size the LSTM head.
        device: Device that both branches are moved to at construction.
        stgcn_weight: Coefficient on the STGCN/Transformer output (default 2.1).
        lstm_weight: Coefficient on the LSTM output (default 0.5).
    """

    def __init__(self, ks, kt, bs, T, n, Lk, p, shapes, device,
                 stgcn_weight=2.1, lstm_weight=0.5):
        super(LSTM_STGCN, self).__init__()
        # NOTE(review): bs[0][0] is used as the LSTM layer count; confirm intended.
        # Both branches go to ``device``. Previously only the STGCN branch was
        # moved, which left the LSTM head on its construction device.
        self.lstm = LSTM(shapes, bs[0][2], bs[0][0], 1).to(device)
        self.stgcn = STGCN(ks, kt, bs, T, n, Lk, p, shapes).to(device)
        self.bs = bs
        self.stgcn_weight = stgcn_weight
        self.lstm_weight = lstm_weight

    def forward(self, x):
        stgcn_out, fused = self.stgcn(x)
        lstm_out = self.lstm(fused)
        return stgcn_out * self.stgcn_weight + lstm_out * self.lstm_weight